from typing import Dict
import subprocess
from collections import defaultdict
from prompta.utils.java_libs import CompactDFA, LearnLibV2Serialization, SAFSerializationDFA, TAFSerializationDFA, FileInputStream, FileOutputStream, Visualization, VPManager, GraphVizBrowserVisualizationProvider, Word
from jpype import JString, JArray



def _load_vp_manager():
    """Create a ``VPManager``, call ``load()`` on it, and return it."""
    manager = VPManager()
    manager.load()
    return manager


# Module-level visualization-provider manager, created and loaded at import time.
vpManager = _load_vp_manager()

def show_graph(graph, disable_graphics=False) -> None:
    """Display *graph* using AutomataLib's ``Visualization.visualize``.

    Args:
        graph: the graph to display (``show_dfa`` passes ``aut.graphView()``).
        disable_graphics: accepted for interface compatibility but currently
            unused. NOTE(review): the previous version also looked up the
            GraphViz browser provider without using it; confirm whether this
            flag was meant to select that provider or to skip display.
    """
    # Second argument passed through unchanged; presumably AutomataLib's
    # ``modal`` flag (TODO confirm against the AutomataLib version in use).
    Visualization.visualize(graph, False)
    

def show_dfa(aut):
    """Visualize automaton *aut* by passing its ``graphView()`` to ``show_graph``."""
    graph = aut.graphView()
    show_graph(graph)

    
def load_dfa(filename):
    """Read a serialized DFA model from *filename*.

    The serializer is chosen from the text after the last ``.`` in the name:
    ``dfa`` -> LearnLibV2Serialization, ``saf`` -> SAFSerializationDFA,
    ``taf`` -> TAFSerializationDFA.

    Args:
        filename: path of the file to read.

    Returns:
        Whatever the chosen serializer's ``readModel`` returns.

    Raises:
        NotImplementedError: if the extension is not one of the above.
    """
    filetype = filename.split('.')[-1]
    # Resolve the serializer before touching the file so an unsupported
    # extension fails fast without opening (and leaking) a stream.
    if filetype == 'dfa':
        serializer = LearnLibV2Serialization
    elif filetype == 'saf':
        serializer = SAFSerializationDFA
    elif filetype == 'taf':
        serializer = TAFSerializationDFA
    else:
        raise NotImplementedError(f"Unsupported DFA file type: {filetype!r}")

    reader = serializer.getInstance()
    istream = FileInputStream(filename)
    try:
        return reader.readModel(istream)
    finally:
        # Always release the Java file handle, even if parsing fails.
        istream.close()


def save_dfa(filename, automaton, alphabet):
    """Serialize *automaton* over *alphabet* into *filename*.

    The serializer is chosen from the text after the last ``.`` in the name:
    ``dfa`` -> LearnLibV2Serialization, ``saf`` -> SAFSerializationDFA,
    ``taf`` -> TAFSerializationDFA.

    Args:
        filename: destination path.
        automaton: the model passed to the serializer's ``writeModel``.
        alphabet: the input alphabet passed to ``writeModel``.

    Raises:
        NotImplementedError: if the extension is not one of the above. The
            check runs before the file is opened, so an unsupported name
            neither creates nor truncates a file.
    """
    filetype = filename.split('.')[-1]
    if filetype == 'dfa':
        serializer = LearnLibV2Serialization
    elif filetype == 'saf':
        serializer = SAFSerializationDFA
    elif filetype == 'taf':
        serializer = TAFSerializationDFA
    else:
        raise NotImplementedError(f"Unsupported DFA file type: {filetype!r}")

    writer = serializer.getInstance()
    ostream = FileOutputStream(filename)
    try:
        writer.writeModel(ostream, automaton, alphabet)
    finally:
        # Release the Java file handle even if serialization fails.
        ostream.close()


def query2str(query):
    """Return a lazy iterator over the cleaned symbols of *query*'s input word.

    Each symbol is converted with ``str`` and stripped of every space and
    every 'ε' character. Returns None when *query* is falsy.
    """
    if not query:
        return None
    drop_chars = str.maketrans('', '', ' ε')
    symbols = query.getInput()
    return (str(symbol).translate(drop_chars) for symbol in symbols)

def word2tuple(word):
    """Convert *word* into a tuple of ``str`` symbols; a falsy *word* yields ``()``."""
    return tuple(map(str, word)) if word else ()

def tuple2word(word):
    """Build a Java ``Word`` from the sized sequence of symbols *word*.

    The symbols are packed into a Java ``String[]`` via ``JArray(JString)``,
    and the whole array (offset 0, full length) is wrapped with
    ``Word.fromArray``.
    """
    length = len(word)
    java_symbols = JArray(JString)(list(word))
    return Word.fromArray(java_symbols, 0, length)

def pta_dfa2compact_dfa(pta_dfa):
    """Copy *pta_dfa* into a new ``CompactDFA`` over the same input alphabet.

    Acceptance of every state is preserved. Any (state, symbol) pair for which
    ``getTransition`` returns None becomes a self-loop on that state in the
    result.

    NOTE(review): states are matched by ``str(state)``, so two distinct states
    with the same string form would be merged. Confirm that state
    representations are unique.
    """
    pta_states = pta_dfa.getStates()
    sigma = pta_dfa.getInputAlphabet()
    result = CompactDFA(sigma)

    init = pta_dfa.getInitialState()
    init_id = result.addInitialState()
    result.setAccepting(init_id, pta_dfa.isAccepting(init))
    id_of = {str(init): init_id}

    # First pass: allocate one CompactDFA state per remaining PTA state.
    for state in pta_states:
        key = str(state)
        if key in id_of:
            continue
        id_of[key] = result.addState(pta_dfa.isAccepting(state))

    # Second pass: copy every transition over the full alphabet.
    for state in pta_states:
        src = id_of[str(state)]
        for symbol in sigma:
            succ = pta_dfa.getTransition(state, symbol)
            if succ is None:
                succ = state  # missing transition -> self-loop
            result.addTransition(src, symbol, id_of[str(succ)])

    return result

# Lenient bool lookup for model answers. Listed keys map to their bool value;
# any other key maps to False because ``bool()`` returns False. Because this is
# a defaultdict, a missed ``[]`` lookup also stores that key.
str2bool = defaultdict(bool, {
    True: True,
    False: False,
    "true": True,
    "True": True,
    "false": False,
    "False": False,
})

# Candidate JSON field names for each logical field, tried in order.
json_keys = dict(
    answer=['answer', 'Answer', 'belongs_to_language'],
    reason=['reason', 'Reason'],
)


def _as_bool(value):
    """Map a JSON value to a bool using ``str2bool``; unknown values give False.

    Uses ``.get`` so unknown values are not inserted into the shared
    ``str2bool`` defaultdict. Unhashable values (JSON lists or objects) are
    treated as False instead of raising ``TypeError``.
    """
    try:
        return str2bool.get(value, False)
    except TypeError:
        return False


def get_value_by_key(result: Dict, key: str, is_boolean=False):
    """Look up a logical field in a parsed JSON response.

    Each alias in ``json_keys[key]`` is tried in order, and the first alias
    present with a non-None value wins.

    Args:
        result: the parsed JSON object.
        key: logical field name; must be a key of ``json_keys``.
        is_boolean: if True, convert the found value with ``str2bool``
            semantics (unrecognized values become False).

    Returns:
        The found value, converted to bool when ``is_boolean`` is set. When
        ``is_boolean`` is set and no alias matches, the first value in a
        non-empty ``result`` is converted as a fallback. Otherwise ``result``
        and a warning are printed and None is returned.

    Raises:
        KeyError: if ``key`` is not a key of ``json_keys``.
    """
    for alias in json_keys[key]:
        value = result.get(alias)
        # Test against None, not truthiness, so a legitimate ``False`` answer
        # is returned instead of falling through to the heuristic below.
        if value is not None:
            return _as_bool(value) if is_boolean else value

    if is_boolean and result:
        # Heuristic fallback: interpret the first field of the response.
        return _as_bool(next(iter(result.values())))

    print(result)
    print(UserWarning(f"Key {key} not found!"))
    return None